{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Knet RNN example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# After installing and starting Julia run the following to install the required packages:\n",
    "# julia> Pkg.init(); for p in (\"CUDAdrv\",\"IJulia\",\"Knet\",\"PyCall\",\"JLD2\"); Pkg.add(p); end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[1m\u001b[36mINFO: \u001b[39m\u001b[22m\u001b[36mChecking out Knet ilkarman...\n",
      "\u001b[39m\u001b[1m\u001b[36mINFO: \u001b[39m\u001b[22m\u001b[36mPulling Knet latest ilkarman...\n",
      "\u001b[39m\u001b[1m\u001b[36mINFO: \u001b[39m\u001b[22m\u001b[36mNo packages to install, update or remove\n",
      "\u001b[39m\u001b[1m\u001b[36mINFO: \u001b[39m\u001b[22m\u001b[36mBuilding Knet\n",
      "\u001b[39m"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "make: 'libknet8.so' is up to date.\n"
     ]
    }
   ],
   "source": [
    "Pkg.checkout(\"Knet\",\"ilkarman\") # make sure we have the right Knet version\n",
    "Pkg.build(\"Knet\")\n",
    "using Knet\n",
    "True=true # so we can read the python params\n",
    "include(\"common/params_lstm.py\");"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "OS: Linux\n",
      "Julia: 0.6.1\n",
      "Knet: 0.8.5+\n",
      "GPU: Tesla K80\n",
      "\n"
     ]
    }
   ],
   "source": [
    "println(\"OS: \", Sys.KERNEL)\n",
    "println(\"Julia: \", VERSION)\n",
    "println(\"Knet: \", Pkg.installed(\"Knet\"))\n",
    "println(\"GPU: \", readstring(`nvidia-smi --query-gpu=name --format=csv,noheader`))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# define model\n",
    "function initmodel()\n",
    "    rnnSpec,rnnWeights = rnninit(EMBEDSIZE,NUMHIDDEN; rnnType=:gru)\n",
    "    inputMatrix = KnetArray(xavier(Float32,EMBEDSIZE,MAXFEATURES))\n",
    "    outputMatrix = KnetArray(xavier(Float32,2,NUMHIDDEN))\n",
    "    return rnnSpec,(rnnWeights,inputMatrix,outputMatrix)\n",
    "end;"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# define loss and its gradient\n",
    "function predict(weights, inputs, rnnSpec)\n",
    "    rnnWeights, inputMatrix, outputMatrix = weights # (1,1,W), (X,V), (2,H)\n",
    "    indices = hcat(inputs...)' # (B,T)\n",
    "    rnnInput = inputMatrix[:,indices] # (X,B,T)\n",
    "    rnnOutput = rnnforw(rnnSpec, rnnWeights, rnnInput)[1] # (H,B,T)\n",
    "    return outputMatrix * rnnOutput[:,:,end] # (2,H) * (H,B) = (2,B)\n",
    "end\n",
    "\n",
    "loss(w,x,y,r)=nll(predict(w,x,r),y)\n",
    "lossgradient = grad(loss);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[1m\u001b[36mINFO: \u001b[39m\u001b[22m\u001b[36mLoading IMDB...\n",
      "\u001b[39m"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  9.804228 seconds (15.66 M allocations: 821.100 MiB, 4.36% gc time)\n",
      "25000-element Array{Array{Int32,1},1}\n",
      "25000-element Array{Int8,1}\n",
      "25000-element Array{Array{Int32,1},1}\n",
      "25000-element Array{Int8,1}\n"
     ]
    }
   ],
   "source": [
    "# load data\n",
    "include(Knet.dir(\"data\",\"imdb.jl\"))\n",
    "@time (xtrn,ytrn,xtst,ytst,imdbdict)=imdb(maxlen=MAXLEN,maxval=MAXFEATURES)\n",
    "for d in (xtrn,ytrn,xtst,ytst); println(summary(d)); end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# prepare for training\n",
    "weights = nothing; knetgc(); # Reclaim memory from previous run\n",
    "rnnSpec,weights = initmodel()\n",
    "optim = optimizers(weights, Adam; lr=LR, beta1=BETA_1, beta2=BETA_2, eps=EPS);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  3.525070 seconds (1.49 M allocations: 80.569 MiB, 0.89% gc time)\n"
     ]
    }
   ],
   "source": [
    "# force precompile (optional)\n",
    "(x,y) = first(minibatch(xtrn,ytrn,BATCHSIZE))\n",
    "@time lossgradient(weights,x,y,rnnSpec);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[1m\u001b[36mINFO: \u001b[39m\u001b[22m\u001b[36mTraining...\n",
      "\u001b[39m"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " 10.559463 seconds (668.98 k allocations: 61.830 MiB, 4.87% gc time)\n",
      "  9.602975 seconds (358.03 k allocations: 44.377 MiB, 6.43% gc time)\n",
      "  9.590503 seconds (358.76 k allocations: 44.389 MiB, 6.38% gc time)\n",
      " 29.754309 seconds (1.39 M allocations: 150.835 MiB, 5.86% gc time)\n"
     ]
    }
   ],
   "source": [
    "# 30s\n",
    "info(\"Training...\")\n",
    "@time for epoch in 1:EPOCHS\n",
    "    @time for (x,y) in minibatch(xtrn,ytrn,BATCHSIZE;shuffle=true)\n",
    "        grads = lossgradient(weights,x,y,rnnSpec)\n",
    "        update!(weights, grads, optim)\n",
    "    end\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[1m\u001b[36mINFO: \u001b[39m\u001b[22m\u001b[36mTesting...\n",
      "\u001b[39m"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  3.055123 seconds (289.27 k allocations: 45.848 MiB, 4.54% gc time)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "0.8529647435897436"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "info(\"Testing...\")\n",
    "total = correct = 0\n",
    "@time for (x,y) in minibatch(xtst,ytst,BATCHSIZE)\n",
    "    total += accuracy(predict(weights,x,rnnSpec), y; average=false)\n",
    "    correct += length(y)\n",
    "end\n",
    "total/correct"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Julia 0.6.1",
   "language": "julia",
   "name": "julia-0.6"
  },
  "language_info": {
   "file_extension": ".jl",
   "mimetype": "application/julia",
   "name": "julia",
   "version": "0.6.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
